package org.wikipedia.miner.annotation.weighting;


import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import org.apache.log4j.Logger;
import org.wikipedia.miner.annotation.Topic;
import org.wikipedia.miner.annotation.TopicDetector;
import org.wikipedia.miner.annotation.preprocessing.PreprocessedDocument;
import org.wikipedia.miner.comparison.ArticleComparer;
import org.wikipedia.miner.model.Article;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.ProgressTracker;
import org.wikipedia.miner.util.RelatednessCache;
import org.wikipedia.miner.util.Result;
import org.wikipedia.miner.util.TopicIndexingSet;

import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.Instance;
import weka.core.Utils;
import weka.core.WekaException;
import org.dmilne.weka.wrapper.Dataset;
import org.dmilne.weka.wrapper.Decider;
import org.dmilne.weka.wrapper.DeciderBuilder;
import org.dmilne.weka.wrapper.InstanceBuilder;

/**
 * A {@link TopicWeighter} that scores each detected topic with the probability,
 * produced by a trained classifier, that the topic is a "key topic" of the document.
 * <p>
 * Typical use: either {@link #train} on a {@link TopicIndexingSet} (or
 * {@link #loadTrainingData}) followed by {@link #buildClassifier} /
 * {@link #buildDefaultClassifier}, or {@link #loadClassifier} directly; after that
 * {@link #getTopicWeights} and {@link #test} can be used.
 */
public class TopicIndexer extends TopicWeighter {

	private static final Logger logger = Logger.getLogger(TopicIndexer.class) ;

	private Wikipedia wikipedia ;
	
	// The features describing a topic; their declaration order is the order handed to the decider builder.
	private enum Attributes {occurances,maxDisambigConfidence,avgDisambigConfidence,relatednessToContext,relatednessToOtherTopics,maxLinkProbability,avgLinkProbability,generality,firstOccurance,lastOccurance,spread} ;
	private Decider<Attributes, Boolean> decider ;
	
	// Training data; null until train() or loadTrainingData() is called, and again after clearTrainingData().
	private Dataset<Attributes, Boolean> dataset ;
	
	// Running count of topics scored by getTopicWeights() over the lifetime of this object.
	int candidatesConsidered = 0 ;
	
	/**
	 * Creates a topic indexer with an untrained decider. A classifier must be built or
	 * loaded before topics can be weighted.
	 * 
	 * @param wikipedia the Wikipedia instance used to compare articles during training and testing
	 * @throws Exception if the decider cannot be constructed
	 */
	@SuppressWarnings("unchecked")
	public TopicIndexer(Wikipedia wikipedia) throws Exception {
		this.wikipedia = wikipedia ;
		
		// NOTE(review): the decider name "LinkDisambiguator" looks copied from another class.
		// It is deliberately left unchanged because it may be embedded in previously saved data - confirm before renaming.
		decider = (Decider<Attributes, Boolean>) new DeciderBuilder<Attributes>("LinkDisambiguator", Attributes.class)
		.setDefaultAttributeTypeNumeric()
		.setClassAttributeTypeBoolean("isKeyTopic")
		.build();
	}
	
	/**
	 * @return the number of topics that have been scored by {@link #getTopicWeights} so far
	 */
	public int getCandidatesConsidered() {
		return candidatesConsidered ;
	}
	
	/**
	 * Weights each topic by the classifier's probability that it is a key topic.
	 * 
	 * @param topics the topics to weight
	 * @return a map from topic id to the probability assigned to the "true" decision
	 * @throws Exception if no classifier has been built or loaded, or if classification fails
	 */
	public HashMap<Integer,Double> getTopicWeights(Collection<Topic> topics) throws Exception {
		
		requireClassifier() ;
		
		HashMap<Integer, Double> topicWeights = new HashMap<Integer, Double>() ;
	
		for (Topic topic: topics) {
		
			// no class label is known at prediction time, so the instance is built unlabelled
			Instance i = getInstance(topic, null) ;
			
			double prob = decider.getDecisionDistribution(i).get(true) ;
			topicWeights.put(topic.getId(), prob) ;
			
			candidatesConsidered++ ;
		}
		
		return topicWeights ;
	}
	
	
	/**
	 * Generates training data from the given set of documents, replacing any training
	 * data gathered previously, and then balances the weights of the resulting instances.
	 * 
	 * @param trainingSet the documents (with their known topics) to train on
	 * @param datasetName currently unused by this method
	 * @param td the topic detector used to gather candidate topics from each document
	 * @throws Exception if topics cannot be detected or instances cannot be built
	 */
	public void train(TopicIndexingSet trainingSet, String datasetName, TopicDetector td) throws Exception{

		dataset = decider.createNewDataset();
		
		ProgressTracker tracker = new ProgressTracker(trainingSet.size(), "training", TopicIndexer.class) ;
		
		for (TopicIndexingSet.Item i: trainingSet) {
			
			train(i, td) ;
			
			tracker.update() ;
		}
		
		weightTrainingInstances() ;
	}
	
	/**
	 * Tests the classifier against the given set of documents. Per-document results and
	 * summary statistics are printed to standard output.
	 * 
	 * @param trainingSet the documents (with their known topics) to test on
	 * @param td the topic detector used to gather candidate topics from each document
	 * @return the result accumulated over all documents
	 * @throws Exception if no classifier has been built or loaded, or if testing a document fails
	 */
	public Result<Integer> test(TopicIndexingSet trainingSet, TopicDetector td) throws Exception{

		requireClassifier() ;
		
		double worstRecall = 1 ;
		double worstPrecision = 1 ;
		
		int docsTested = 0 ;
		int perfectRecall = 0 ;
		int perfectPrecision = 0 ;

		Result<Integer> r = new Result<Integer>() ;

		ProgressTracker tracker = new ProgressTracker(trainingSet.size(), "Testing", TopicIndexer.class) ;
		for (TopicIndexingSet.Item item:trainingSet) {
					
			docsTested ++ ;
			
			Result<Integer> ir = test(item, td) ;
			
			if (ir.getRecall() ==1) perfectRecall++ ;
			if (ir.getPrecision() == 1) perfectPrecision++ ;
			
			worstRecall = Math.min(worstRecall, ir.getRecall()) ;
			worstPrecision = Math.min(worstPrecision, ir.getPrecision()) ;
			
			r.addIntermediateResult(ir) ;
			
			tracker.update() ;
		}

		System.out.println("worstR:" + worstRecall + ", worstP:" + worstPrecision) ;
		System.out.println("tested:" + docsTested + ", perfectR:" + perfectRecall + ", perfectP:" + perfectPrecision) ;

		return r ;
	}

	
	/**
	 * Saves the current training data (generated by train() or loaded by
	 * loadTrainingData()) to the given file.
	 * 
	 * @param file the file to save the training data to
	 * @throws IllegalStateException if there is no training data to save
	 * @throws Exception if the data cannot be written
	 */
	public void saveTrainingData(File file) throws Exception {
		
		logger.info("saving training data") ;
		
		requireTrainingData() ;
		dataset.save(file) ;
	}

	/**
	 * Loads training data from the given file and balances the weights of the
	 * resulting instances. If no dataset exists yet (nothing has been trained or
	 * loaded, or the data was cleared) an empty one is created first.
	 * 
	 * @param file the file to load the training data from
	 * @throws Exception if the file cannot be read or does not contain valid training data
	 */
	public void loadTrainingData(File file) throws Exception{
		logger.info("loading training data") ;
		
		// previously this dereferenced a null dataset when called before train()
		if (dataset == null)
			dataset = decider.createNewDataset() ;
		
		dataset.load(file) ;
		weightTrainingInstances() ;
	}
	
	/**
	 * Discards the current training data. The classifier, if already built, is unaffected.
	 */
	public void clearTrainingData() {
		dataset = null ;
	}

	/**
	 * Saves the classifier to the given file.
	 * 
	 * @param file the file to save the classifier to
	 * @throws IOException if the file cannot be written
	 */
	public void saveClassifier(File file) throws IOException {
		logger.info("saving classifier") ;
		
		decider.save(file) ;
	}

	/**
	 * Loads the classifier from the given file.
	 * 
	 * @param file the file to load the classifier from
	 * @throws Exception if the classifier cannot be loaded
	 */
	public void loadClassifier(File file) throws Exception {
		logger.info("loading classifier") ;
		
		decider.load(file) ;
	}

	/**
	 * Trains the given classifier on the current training data.
	 * 
	 * @param classifier the classifier to train
	 * @throws IllegalStateException if there is no training data
	 * @throws Exception if training fails
	 */
	public void buildClassifier(Classifier classifier) throws Exception {
		logger.info("building classifier") ;
		
		requireTrainingData() ;
		decider.train(classifier, dataset) ;
	}
	
	
	/**
	 * Trains a {@link Bagging} classifier, configured by a fixed option string that
	 * names weka.classifiers.trees.J48 as the base learner, on the current training data.
	 * 
	 * @throws IllegalStateException if there is no training data
	 * @throws Exception if the options are rejected or training fails
	 */
	public void buildDefaultClassifier() throws Exception {
		logger.info("building classifier") ;
		
		requireTrainingData() ;
		
		Classifier classifier = new Bagging() ;
		classifier.setOptions(Utils.splitOptions("-P 10 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -U -M 2")) ;
		decider.train(classifier, dataset) ;
	}
	
	/**
	 * Adds one labelled training instance per topic detected in the given item's document.
	 */
	private void train(TopicIndexingSet.Item item, TopicDetector td) throws Exception{
		
		// a fresh relatedness cache is used for every document
		RelatednessCache rc = new RelatednessCache(new ArticleComparer(wikipedia)) ;
		
		Collection<Topic> topics = td.getTopics(item.getDocument(), rc) ;
		for (Topic topic: topics) 
			dataset.add(getInstance(topic, item.isTopic(topic))) ;
	}
	
	/**
	 * Detects and weights the topics of the given item's document, and compares the ids of
	 * those weighted above 0.5 against the item's known topic ids.
	 */
	private Result<Integer> test(TopicIndexingSet.Item item, TopicDetector td) throws Exception{
		
		RelatednessCache rc = new RelatednessCache(new ArticleComparer(wikipedia)) ;
		
		Collection<Topic> topics = td.getTopics(item.getDocument(), rc) ;
		
		ArrayList<Topic> weightedTopics = this.getWeightedTopics(topics) ;
		
		HashSet<Integer> autoIds = new HashSet<Integer>() ;
		for (Topic topic: weightedTopics) {
			if (topic.getWeight() > 0.5) 
				autoIds.add(topic.getId()) ;
		}
		
		Result<Integer> result = new Result<Integer>(autoIds,item.getTopicIds()) ;
		System.out.println(" - " + result) ;
		return result ;
	}
	
	/**
	 * Builds a classifier instance describing the given topic.
	 * 
	 * @param topic the topic to describe
	 * @param isKeyTopic the class label, or null to leave the instance unlabelled
	 */
	private Instance getInstance(Topic topic, Boolean isKeyTopic) throws Exception {
		
		InstanceBuilder<Attributes,Boolean> ib = decider.getInstanceBuilder()
		.setAttribute(Attributes.occurances, topic.getOccurances())
		.setAttribute(Attributes.maxDisambigConfidence, topic.getMaxDisambigConfidence())
		.setAttribute(Attributes.avgDisambigConfidence, topic.getAverageDisambigConfidence())
		.setAttribute(Attributes.relatednessToContext, topic.getRelatednessToContext())
		.setAttribute(Attributes.relatednessToOtherTopics, topic.getRelatednessToOtherTopics())
		.setAttribute(Attributes.maxLinkProbability, topic.getMaxLinkProbability())
		.setAttribute(Attributes.avgLinkProbability, topic.getAverageLinkProbability())
		.setAttribute(Attributes.generality, topic.getGenerality())
		.setAttribute(Attributes.firstOccurance, topic.getFirstOccurance())
		.setAttribute(Attributes.lastOccurance, topic.getLastOccurance())
		.setAttribute(Attributes.spread, topic.getSpread()) ;
		
		if (isKeyTopic != null) 
			ib = ib.setClassAttribute(isKeyTopic) ;
		
		return ib.build() ;
	}
	
	/**
	 * Throws if the decider has no classifier yet.
	 */
	private void requireClassifier() throws WekaException {
		if (!decider.isReady()) 
			throw new WekaException("You must build (or load) classifier first.") ;
	}
	
	/**
	 * Throws if there is no training data to operate on.
	 */
	private void requireTrainingData() {
		if (dataset == null)
			throw new IllegalStateException("No training data available; call train() or loadTrainingData() first.") ;
	}
	
	/**
	 * Re-weights the training instances so that the two classes carry equal total weight:
	 * an instance whose class makes up a fraction p of the data is given weight 0.5/p.
	 * <p>
	 * The class is read via {@link Instance#classValue()} rather than a fixed attribute
	 * index. The previous hard-coded index 3 pointed at the relatednessToContext feature,
	 * not the class label. This assumes the decider's datasets have their class index set
	 * - TODO confirm against the weka wrapper.
	 */
	//TODO: this should really be refactored as a separate filter
	@SuppressWarnings("unchecked")
	private void weightTrainingInstances() {

		// counts of instances whose class value is 0, and of all the others
		double zeroClassInstances = 0 ;
		double otherClassInstances = 0 ; 

		Enumeration<Instance> e = dataset.enumerateInstances() ;

		while (e.hasMoreElements()) {
			Instance i = (Instance) e.nextElement() ;

			if (i.classValue() == 0) 
				zeroClassInstances ++ ;
			else
				otherClassInstances ++ ;
		}

		// fraction of instances belonging to the class with value 0; NaN for an empty
		// dataset, which is harmless because the loop below then does nothing
		double p = zeroClassInstances / (zeroClassInstances + otherClassInstances) ;

		e = dataset.enumerateInstances() ;

		while (e.hasMoreElements()) {
			Instance i = (Instance) e.nextElement() ;

			// each branch only runs for a class that has at least one instance, so its divisor is non-zero
			if (i.classValue() == 0) 
				i.setWeight(0.5 * (1.0/p)) ;
			else
				i.setWeight(0.5 * (1.0/(1-p))) ;
		}

	}

}
